# ========================================================================= #
# Project: Lexical Ambiguity in Political Rhetoric (BJPolS)
# - Script: Validate sentence embeddings
# - Author: Patrick Kraft (patrickwilli.kraft@uc3m.es)
# ========================================================================= #


# Load packages and custom functions --------------------------------------

source(here::here("code/00-func.R"))


# Prepare sentence embeddings ---------------------------------------------

# Build the modelling data: one row per sentence with its embedding vector and
# the single moral foundation matched by the dictionary in in/mfd2.0.dic.
df <- read_csv(here("out/sentences.csv")) %>%
  # The embedding file is read without a header row; its auto-generated
  # column names are picked up by starts_with("X") in the final select()
  bind_cols(read_csv(here("out/embeddings.csv"), col_names = FALSE)) %>%
  bind_cols(
    corpus(.$sentence) %>%
      tokens() %>%
      tokens_lookup(
        dictionary = dictionary(file = here("in/mfd2.0.dic"), format = "LIWC")
      ) %>%
      dfm() %>%
      convert("data.frame") %>%
      # A foundation is flagged when any of its virtue or vice terms matched
      transmute(
        Care = (care.virtue + care.vice) > 0,
        Fairness = (fairness.virtue + fairness.vice) > 0,
        Loyalty = (loyalty.virtue + loyalty.vice) > 0,
        Authority = (authority.virtue + authority.vice) > 0,
        Sanctity = (sanctity.virtue + sanctity.vice) > 0
      )
  ) %>%
  # Keep sentences that match exactly one foundation, excluding party "Liberal"
  filter(
    (Care + Fairness + Loyalty + Authority + Sanctity) == 1,
    party != "Liberal"
  ) %>%
  # Exactly one flag is TRUE per remaining row, so the first matching branch
  # is the sentence's (unique) foundation label
  mutate(
    mft = factor(case_when(
      Care ~ "Care",
      Fairness ~ "Fairness",
      Loyalty ~ "Loyalty",
      Authority ~ "Authority",
      Sanctity ~ "Sanctity"
    ))
  ) %>%
  select(sentence, mft, starts_with("X"))


# Partition dataframe for repeated 10-fold cross-validation ---------------

# Fix the RNG seed so the train/test split (and the resampling that follows)
# is reproducible.
set.seed(42)

# Row indices of the training set: 75% of sentences, sampled with the
# foundation label as the outcome; returned as a matrix (list = FALSE) so it
# can be used directly for row indexing.
inTrain <- createDataPartition(
  y = df$mft,
  p = .75,
  list = FALSE)

# Drop the raw text so only the label and the embedding columns remain.
training <- data.frame(select(df, -sentence))[inTrain, ]
testing  <- data.frame(select(df, -sentence))[-inTrain, ]

# Resampling scheme for tuning: 10-fold cross-validation repeated 5 times.
# `repeats` is the number of repetitions, NOT the number of folds; the fold
# count is `number`, which caret defaults to 10 for CV methods. It is spelled
# out here so the scheme is unambiguous (results are unchanged).
ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 5)


# Train k-nearest neighbors -----------------------------------------------

# Fit a k-nearest-neighbors classifier predicting the foundation label from
# all remaining (embedding) columns; 10 candidate tuning values are compared
# using the resampling scheme defined in `ctrl`.
knnFit <- train(
  mft ~ .,
  data = training,
  method = "knn",
  trControl = ctrl,
  tuneLength = 10
)


# Check out-of-sample predictions -----------------------------------------

# Predict labels for the held-out rows and cross-tabulate them against the
# dictionary-based labels, reporting precision/recall instead of
# sensitivity/specificity.
knnClasses <- predict(knnFit, newdata = testing)
confmat <- confusionMatrix(
  data = knnClasses,
  reference = testing$mft,
  mode = "prec_recall"
)

# Overall statistics are reported with three decimals below
confmat$overall <- round(confmat$overall, digits = 3)

## Print overall accuracy in table
# Assemble the overall-accuracy table as character cells and write it to a
# LaTeX file. Numbers are formatted with sprintf("%.3f") so every cell shows
# exactly three decimals (as.character() would drop trailing zeros, e.g.
# "0.85" instead of "0.850"). `[[` extracts the bare value without its name.
tribble(
  ~Measure, ~Value,
  "Accuracy", sprintf("%.3f", confmat$overall[["Accuracy"]]),
  "95% CI", sprintf("(%.3f, %.3f)",
                    confmat$overall[["AccuracyLower"]],
                    confmat$overall[["AccuracyUpper"]]),
  "No Information Rate", sprintf("%.3f", confmat$overall[["AccuracyNull"]]),
  # P-values below the reporting precision are shown as "< 0.001"
  "P-Value [Acc > NIR]",
  if (isTRUE(confmat$overall[["AccuracyPValue"]] < 0.001)) {
    "< 0.001"
  } else {
    sprintf("%.3f", confmat$overall[["AccuracyPValue"]])
  }
) %>%
  # The caption describes the resampling actually used for tuning:
  # trainControl(method = "repeatedcv", repeats = 5) with caret's default of
  # 10 folds, i.e. 10-fold cross-validation repeated 5 times.
  xtable(caption = paste(
    "Out-of-sample accuracy predicting moral foundations based on sentence",
    "embeddings using k-nearest neighbors. Model was trained on 75\\% of",
    "available speeches using 10-fold cross-validation repeated 5 times."
  ),
         label = "tab:accuracy",
         align = "lrc") %>%
  print(file = here("out/appA1-accuracy.tex"),
        table.placement = "ht",
        hline.after = c(0, nrow(.)),
        include.rownames = FALSE,
        include.colnames = FALSE)

## Plot precision/recall/f1 by class
# Grouped bar chart of precision, recall, and F1 for each foundation, saved
# as a PNG. Row names of the by-class matrix carry the class labels
# ("Class: <foundation>"), which as_tibble() drops, so they are re-attached
# from the original matrix.
confmat$byClass %>%
  as_tibble() %>%
  transmute(
    foundation = factor(
      gsub("Class: ", "", rownames(confmat$byClass)),
      levels = c("Care", "Fairness", "Loyalty", "Authority", "Sanctity")
    ),
    Precision,
    Recall,
    F1
  ) %>%
  pivot_longer(
    cols = c(Precision, Recall, F1),
    names_to = "Measure",
    values_to = "Value"
  ) %>%
  # Fix the legend/bar order of the three measures
  mutate(Measure = factor(Measure, levels = c("Precision", "Recall", "F1"))) %>%
  ggplot(aes(x = foundation, y = Value, fill = Measure)) +
  geom_col(position = "dodge") +
  scale_fill_brewer(palette = "Dark2") +
  ylim(0, 1) +
  labs(x = NULL, y = NULL) +
  theme_mft()
# ggsave() writes the most recently created plot
ggsave(here("out/appA2-precision.png"), width = 5, height = 3, dpi = 600)
